#include "db_mgr.h"

#include <sstream>

#include <cppconn/driver.h>
#include <cppconn/exception.h>
#include <cppconn/warning.h>
#include <cppconn/metadata.h>
#include <cppconn/prepared_statement.h>
#include <cppconn/resultset.h>
#include <cppconn/resultset_metadata.h>
#include <cppconn/statement.h>
#include <mysql_driver.h>
#include <mysql_connection.h>

#include <common/log/log.h>

// Destructor. Nothing to do by hand: every member releases its own resources
// when destroyed. It is defaulted here, out of line, so that it is instantiated
// in this translation unit, where the connector headers are included.
DBMgr::~DBMgr() = default;

bool DBMgr::init()
{
    if (!cfg_.init())
    {
        LOG_ERROR("DB CONFIG INIT FAILED.");
        return false;
    }

    try
    {
        sql::Driver * driver = sql::mysql::get_driver_instance();
        if (!driver)
        {
            LOG_ERROR("get_mysql_driver_instance FAILED.");
            return false;
        }
        const DBConfigInfo& dbInfo = cfg_.getDBInfo();
        std::stringstream ss;
        ss << dbInfo.addr << ":" << dbInfo.port;
        conn_.reset(driver->connect(ss.str().c_str(), dbInfo.user.c_str(), dbInfo.passwd.c_str()));
        if (!conn_)
        {
            LOG_ERROR("No sql connected.");
            return false;
        }

        // create `login_server` database
        std::unique_ptr<sql::Statement> stmt(conn_->createStatement());
        stmt->execute(dbInfo.sqlSchema);

        // create `tables` table
        conn_->setSchema(dbInfo.dbname);
        stmt->execute(dbInfo.sqlVersion);

        // got every table version info from `tables` table
        std::unique_ptr<sql::ResultSet> res(stmt->executeQuery(" select tablename, version from `tables` "));
        std::map<std::string, uint32_t> versionInfos;
        while (res->next())
        {
            versionInfos[res->getString(1)] = res->getInt(2);
        }

        // create and update all table defined in db.xml
        auto getTableVersion = [&versionInfos](const std::string& tableName)->int32_t
        {
            auto it = versionInfos.find(tableName);
            return it == versionInfos.end() ? 0 : it->second;
        };

        for (auto& it : dbInfo.sqlTables)
        {
            int32_t version = getTableVersion(it.first);

            for (auto& itv : it.second)
            {
                if (itv.first <= version)
                {
                    continue;
                }

                stmt->executeUpdate(itv.second);

                std::unique_ptr<sql::PreparedStatement> prepStmt(conn_->prepareStatement("replace into `tables` (`tablename`,`version`) values ( ?, ? )"));
                prepStmt->setString(1, it.first);
                prepStmt->setInt(2, itv.first);

                if (prepStmt->executeUpdate() <= 0)
                {
                    LOG_ERROR("Update table version failed, table: " << it.first << " version: " << itv.first);
                    return false;
                }
            }
        }
    }
    catch (sql::SQLException &e)
    {
        LOG_ERROR("# ERR: " << e.what()
                  << " (MySQL error code: " << e.getErrorCode()
                  << ", SQLState: " << e.getSQLState() << " )");
        return false;
    }
    return true;
}
